import glob
from itertools import combinations
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from scipy.spatial.distance import cosine
from scipy.spatial.distance import euclidean


# Named groups of game IDs kept for reference; `log_names` below is the
# selection the rest of the script actually iterates over.
log_names_our = ["KGRU", "KPIV", "KPLI", "LMCT", "MMOO", "PQPU", "RQKO",
                 "VJPP", "WMWV", "WZNT", "XSSP", "YFOK", "YLTB", "YYYX"]

log_names_graph = ["CVSL", "HHFX", "HRVO", "LFXF", "LHUB", "LWXT", "OMMY",
                   "RVRM", "TVIX", "VOMC", "WDOW", "WRQT", "XMCA", "XWCN",
                   "ZSLW"]

# Raw log DataFrames per game: game ID -> list of frames (filled by the
# loading loop).
dataframes = {}

# Game IDs to analyse. Previous selections are kept commented out so they can
# be swapped back in; only one `log_names` assignment should be active.
# log_names =  log_names_graph
# log_names = ["BGES"]
# log_names = ['GJZN', 'TNSH', 'WCIY', 'VUWJ', 'XMXK', 'TYJU', 'VGDN', 'SGUZ', 'MRVT', 'RZSG', 'VZRV', 'CZDG', 'FRDV', 'NUWW', 'XQFB', 'WZZB', 'AXSA', 'WRIS', 'SNZA', 'OSHM']
# log_names = ['OFDP', 'FSSW', 'UGNN', 'ERYX']
log_names = [
    'VAJJ', 'HGUF', 'BQWM', 'RUTC', 'YRMS', 'PMBA', 'AIOT', 'GYRZ', 'QEUI',
    'FMWP', 'LZCH', 'KBDG', 'VPBM', 'SHHA', 'ZLGC', 'NMBA', 'DAPV', 'STTW',
    'IFCB', 'SKIT', 'IOFG', 'TTCD', 'LEPK', 'NAIR', 'UQGY', 'HCUO', 'QRSN',
    'HVNJ',
]
# Load the raw log files for each selected game into `dataframes[game_id]`.
for log_name in log_names:
    # NOTE(review): this pattern does not depend on `log_name`, so every game
    # currently receives the same set of files. Presumably it should select
    # only this game's logs; TODO confirm the intended file layout.
    # NOTE(review): files are matched as *.json but parsed with read_csv;
    # confirm the on-disk format.
    # glob returns matches in arbitrary order; sort so each game's list of
    # frames has a deterministic order.
    log_files = sorted(glob.glob("evaluation/logs/*.json"))
    dataframes[log_name] = [pd.read_csv(path) for path in log_files]

def _clean_log_frame(raw: pd.DataFrame) -> pd.DataFrame:
    """Return a cleaned, numeric copy of one raw log frame.

    The first data row of ``raw`` is promoted to the column labels. Duplicate
    ``round`` entries are collapsed, keeping the last one. The result is
    indexed by ``round`` and converted to float. ``raw`` itself is not
    modified, so this step can be re-run safely.
    """
    header = raw.iloc[0]
    df = raw.iloc[1:].set_axis(header, axis=1).reset_index(drop=True)
    df = df.drop_duplicates(keep='last', subset=['round'], ignore_index=True)
    df = df.set_index('round')
    # Optional column subset: df = df[['Minion-1', 'Minion-2']]
    return df.astype(float)


# Cleaned per-game frames: game ID -> list of float DataFrames indexed by round.
clean_dfs = {
    game_id: [_clean_log_frame(raw) for raw in dfs]
    for game_id, dfs in dataframes.items()
}
def _self_player(df: pd.DataFrame) -> str:
    """Return the label of the first column whose values are all zero.

    NOTE(review): this appears to identify the player who wrote the log
    (presumably they score themselves 0); confirm against the logging code.

    Raises:
        ValueError: if no column is all zeros.
    """
    zero_cols = df.columns[(df == 0).all()]
    if len(zero_cols) == 0:
        raise ValueError(f"no all-zero column found among {list(df.columns)}")
    return zero_cols[0]


# For every pair of logs in a game, compute the per-round Euclidean distance
# between the two score vectors. There is one column per pair, labelled
# "(<player1 number> , <player2 number>)".
agreement_dfs = {}
for game_id, dfs in clean_dfs.items():
    result_df = pd.DataFrame()
    for df1, df2 in combinations(dfs, 2):
        # Assumes labels look like "<name>-<number>", e.g. "Minion-1".
        player1_num = _self_player(df1).split('-')[1]
        player2_num = _self_player(df2).split('-')[1]
        # euclidean() compares by position, so put df2's columns in df1's
        # order. Otherwise two files with differently ordered columns would
        # silently compare different players.
        aligned = df2[df1.columns]
        # Swap `euclidean` for `cosine` here to use cosine distance instead.
        distances = df1.apply(
            lambda row: euclidean(row, aligned.loc[row.name]), axis=1
        )
        result_df[f'({player1_num} , {player2_num})'] = distances
    agreement_dfs[game_id] = result_df

# Print with pandas' display truncation disabled. option_context restores the
# previous settings on exit, even if printing raises.
with pd.option_context(
    'display.max_rows', None,
    'display.max_columns', None,
    'display.width', None,
    'display.max_colwidth', None,
):
    # e.g. print(agreement_dfs['KPIV']) to inspect one game's full table
    print(dataframes.keys())

# Save one annotated heatmap per game: rows are rounds, columns are player
# pairs.
HEATMAP_DIR = Path('Archive 2')
# savefig does not create missing directories.
HEATMAP_DIR.mkdir(parents=True, exist_ok=True)
for game_id in log_names:
    fig = plt.figure(figsize=(12, 8))
    print(game_id)
    print(agreement_dfs[game_id])
    sns.heatmap(agreement_dfs[game_id], annot=True, cmap='coolwarm', fmt=".2f")
    plt.title(f'Agreement (Euclidean Distances) for Game {game_id} (sumproduct)')
    plt.xlabel('Player Pair')
    plt.ylabel('Round')
    plt.savefig(HEATMAP_DIR / f'agreement_heatmap_{game_id}_evilonly_sum.png')
    # Close each figure once it is saved. Otherwise every figure stays open:
    # matplotlib warns past 20, and a later plt.show() displays all of them.
    plt.close(fig)

# plt.figure(figsize=(10, 6))
# for game_id, df in agreement_dfs.items():
#     plt.plot(df.index, df['average'], marker='o', linestyle='-', label=game_id)
# plt.title(f'Average Euclidean Distance')
# plt.xlabel('Round')
# plt.ylabel('Average Distance')
# plt.grid(True)
# plt.legend()
# plt.show()

# Place every game's per-pair distance columns side by side; rows are rounds.
combined_df = pd.concat(list(agreement_dfs.values()), axis=1)

# After transposing, each round is one violin drawn from all pair distances
# recorded for that round.
_, violin_ax = plt.subplots(figsize=(12, 8))
sns.violinplot(data=combined_df.T, inner="quartile", ax=violin_ax)
violin_ax.set_title('Distribution of Average Distances Across Rounds sumproduct Agent')
violin_ax.set_xlabel('Round')
violin_ax.set_ylabel('Euclidean Distance')
violin_ax.tick_params(axis='x', labelrotation=90)
violin_ax.grid(True)
plt.show()

# print(clean_dfs['MMOO'][0].columns[(clean_dfs['MMOO'][0] == 0).all()])